{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# CurvLearn Tutorial\n",
    "In this  tutorial, you will learn how to build a non-Euclidean binary classification model, including\n",
    "- define manifold and riemannian tensors.\n",
    "- build non-Euclidean models from manifold operations.\n",
    "- define loss function and apply riemannian optimization.\n",
    "\n",
    "Let's start!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting curvlearn\n",
      "  Using cached curvlearn-0.1.0-py3-none-any.whl (29 kB)\n",
      "Requirement already satisfied: numpy<2.0.0,>=1.16.5 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from curvlearn) (1.20.3)\n",
      "Requirement already satisfied: tensorflow<2.0.0,>=1.15.0 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from curvlearn) (1.15.0)\n",
      "Requirement already satisfied: google-pasta>=0.1.6 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (0.2.0)\n",
      "Requirement already satisfied: absl-py>=0.7.0 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (0.13.0)\n",
      "Requirement already satisfied: tensorboard<1.16.0,>=1.15.0 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.15.0)\n",
      "Requirement already satisfied: astor>=0.6.0 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (0.8.1)\n",
      "Requirement already satisfied: six>=1.10.0 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.16.0)\n",
      "Requirement already satisfied: opt-einsum>=2.3.2 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (3.3.0)\n",
      "Requirement already satisfied: wrapt>=1.11.1 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.12.1)\n",
      "Requirement already satisfied: wheel>=0.26 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (0.36.2)\n",
      "Requirement already satisfied: termcolor>=1.1.0 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.1.0)\n",
      "Requirement already satisfied: tensorflow-estimator==1.15.1 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.15.1)\n",
      "Requirement already satisfied: gast==0.2.2 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (0.2.2)\n",
      "Requirement already satisfied: grpcio>=1.8.6 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.36.1)\n",
      "Requirement already satisfied: keras-applications>=1.0.8 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.0.8)\n",
      "Requirement already satisfied: keras-preprocessing>=1.0.5 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (1.1.2)\n",
      "Requirement already satisfied: protobuf>=3.6.1 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorflow<2.0.0,>=1.15.0->curvlearn) (3.17.2)\n",
      "Requirement already satisfied: h5py in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from keras-applications>=1.0.8->tensorflow<2.0.0,>=1.15.0->curvlearn) (3.2.1)\n",
      "Requirement already satisfied: werkzeug>=0.11.15 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0.0,>=1.15.0->curvlearn) (0.16.1)\n",
      "Requirement already satisfied: setuptools>=41.0.0 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0.0,>=1.15.0->curvlearn) (52.0.0.post20210125)\n",
      "Requirement already satisfied: markdown>=2.6.8 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0.0,>=1.15.0->curvlearn) (3.3.4)\n",
      "Requirement already satisfied: importlib-metadata in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow<2.0.0,>=1.15.0->curvlearn) (1.7.0)\n",
      "Requirement already satisfied: cached-property in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from h5py->keras-applications>=1.0.8->tensorflow<2.0.0,>=1.15.0->curvlearn) (1.5.2)\n",
      "Requirement already satisfied: zipp>=0.5 in /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages (from importlib-metadata->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow<2.0.0,>=1.15.0->curvlearn) (3.5.0)\n",
      "Installing collected packages: curvlearn\n",
      "Successfully installed curvlearn-0.1.0\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "pip install curvlearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 500\n",
    "batch_size = 1024\n",
    "log_steps = 100\n",
    "learning_rate = 1e-3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "CurvLearn now supports the following manifolds\n",
    "- Constant curvature manifolds\n",
    "    - ```curvlearn.manifolds.Euclidean``` - Euclidean space with zero curvature.\n",
    "    - ```curvlearn.manifolds.Stereographic``` - Constant curvature stereographic projection model. The curvature can be positive, negative or zero.\n",
    "    - ```curvlearn.manifolds.PoincareBall``` - The stereographic projection of the Lorentz model with negative curvature.\n",
    "    - ```curvlearn.manifolds.ProjectedSphere``` - The stereographic projection of the sphere model with positive curvature.\n",
    "- Mixed curvature manifolds\n",
    "    - ```curvlearn.manifolds.Product``` - Mixed-curvature space consists of multiple manifolds with different curvatures.\n",
    "\n",
    "In this tutorial, we use the stereographic model with trainable curvature. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/2369927432.py:4: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "Stereographic\n"
     ]
    }
   ],
   "source": [
    "from curvlearn.manifolds import Stereographic\n",
    "\n",
    "manifold = Stereographic()\n",
    "curvature = tf.get_variable(name=\"curvature\", initializer=tf.constant(0.0, dtype=manifold.dtype), trainable=True)\n",
    "\n",
    "print(manifold.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate random binary classification dataset.\n",
    "1 sprase feature and 8 dense features are used to predict the 0/1 label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/3446372793.py:16: The name tf.data.make_one_shot_iterator is deprecated. Please use tf.compat.v1.data.make_one_shot_iterator instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "global_step = tf.get_variable(name='global_step',initializer=tf.constant(0), trainable=False)\n",
    "\n",
    "dense = np.random.rand(10000, 8)\n",
    "sparse = np.random.randint(0, 1000, [10000, 1])\n",
    "labels = np.random.choice([0, 1], size=10000, replace=True)\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensor_slices(\n",
    "    {\n",
    "        'dense': tf.cast(dense, tf.float32),\n",
    "        'sparse': tf.cast(sparse, tf.int32),\n",
    "        'labels': tf.cast(labels, tf.float32)\n",
    "    }\n",
    ")\n",
    "dataset = dataset.shuffle(batch_size * 10).batch(batch_size, drop_remainder=False).repeat(epochs)\n",
    "\n",
    "iterator = tf.data.make_one_shot_iterator(dataset)\n",
    "batch = iterator.get_next()\n",
    "dense, sparse, labels = batch['dense'], batch['sparse'], batch['labels']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define tensors in the specific manifold can be simply realized through the wrapper function `manifold.variable`.\n",
    "According to the variable name, tensors are optimized in different ways.\n",
    "- \"*RiemannianParameter*\" is contained in the variable name: the variable is a riemannian tensor, and should be optimized by riemannian optimizers.\n",
    "- Otherwise: the variable is an euclidean(tangent) tensor and is projected into the manifold. In this case, riemannian optimizers behave equivalently to vanilla euclidean optimizers.\n",
    "\n",
    "Here we optimize dense embedding in euclidean space and sparse embedding in curved space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/tensorflow_core/python/ops/clip_ops.py:172: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "embedding_table = tf.get_variable(\n",
    "    name='RiemannianParameter/embedding',\n",
    "    shape=(1000, 8),\n",
    "    dtype=manifold.dtype,\n",
    "    initializer=tf.truncated_normal_initializer(0.001)\n",
    ")\n",
    "embedding_table = manifold.variable(embedding_table, c=curvature)\n",
    "sparse_embedding = tf.squeeze(tf.nn.embedding_lookup(embedding_table, sparse), axis=1)\n",
    "dense_embedding = manifold.variable(dense, c=curvature)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Building riemannian neural networks requires replacing euclidean tensor operations with manifold operations.\n",
    "\n",
    "CurvLearn now supports the following basic operations.\n",
    "- ```variable(t, c)``` - Defines a riemannian variable from manifold or tangent space at origin according to its name.\n",
    "- ```to_manifold(t, c, base)``` - Converts a tensor ```t``` in the tangent space of ```base``` point to the manifold.\n",
    "- ```to_tangent(t, c, base)``` - Converts a tensor ```t``` in the manifold to the tangent space of ```base``` point.\n",
    "- ```weight_sum(tensor_list, a, c)``` - Computes the sum of tensor list ```tensor_list``` with weight list ```a```.\n",
    "- ```mean(t, c, axis)``` - Computes the average of elements along ```axis``` dimension of a tensor ```t```.\n",
    "- ```sum(t, c, axis)``` - Computes the sum of elements along ```axis``` dimension of a tensor ```t```.\n",
    "- ```concat(tensor_list, c, axis)``` - Concatenates tensor list ```tensor_list``` along ```axis``` dimension.\n",
    "- ```matmul(t, m, c)``` - Multiplies tensor ```t``` by euclidean matrix ```m```.\n",
    "- ```add(x, y, c)``` - Adds tensor ```x``` and tensor ```y```.\n",
    "- ```add_bias(t, b, c)``` - Adds a euclidean bias vector ```b``` to tensor ```t```.\n",
    "- ```activation(t, c_in, c_out, act)``` - Computes the value of  activation function ```act``` for the input tensor ```t```.\n",
    "- ```linear(t, in_dim, out_dim, c_in, c_out, act, scope)``` - Computes the linear transformation for the input tensor ```t```.\n",
    "- ```distance(src, tar, c)``` - Computes the squared geodesic/distance between ```src``` and ```tar```.\n",
    "\n",
    "Complex operations can be decomposed into basic operations explicitly or realized in tangent space implicitly.\n",
    "\n",
    "Here we use two fully-connected layers as our model backbone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/curvlearn/manifolds/manifold.py:244: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/curvlearn/manifolds/manifold.py:244: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/curvlearn/manifolds/manifold.py:249: calling GlorotNormal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    }
   ],
   "source": [
    "x = manifold.concat([sparse_embedding, dense_embedding], axis=1, c=curvature)\n",
    "x = manifold.linear(x, 16, 256, curvature, curvature, tf.nn.elu, 'hidden_layer_1')\n",
    "x = manifold.linear(x, 256, 32, curvature, curvature, tf.nn.elu, 'hidden_layer_2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice non-euclidean geometry can only be expressed by geodesics, we use the fermi-dirac decoder to decode the distance and generate the probabilities. Cross entropy is used as the loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "origin = manifold.proj(tf.zeros([32], dtype=manifold.dtype), c=curvature)\n",
    "distance = tf.squeeze(manifold.distance(x, origin, c=curvature))\n",
    "loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=1.0 - 1.0*distance))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CurvLearn now supports the following optimizers.\n",
    "- ```curvlearn.optimizers.rsgd``` - Riemannian stochastic gradient optimizer.\n",
    "- ```curvlearn.optimizers.radagrad``` - Riemannian Adagrad optimizer.\n",
    "- ```curvlearn.optimizers.radam``` - Riemannian Adam optimizer.\n",
    "\n",
    "Here we apply riemannian adam optimizer to minimize the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/curvlearn/optimizers/rsgd.py:27: The name tf.train.GradientDescentOptimizer is deprecated. Please use tf.compat.v1.train.GradientDescentOptimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/curvlearn/optimizers/radagrad.py:27: The name tf.train.AdagradOptimizer is deprecated. Please use tf.compat.v1.train.AdagradOptimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/curvlearn/optimizers/radam.py:27: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/zrxu/opt/anaconda3/envs/dev/lib/python3.7/site-packages/tensorflow_core/python/framework/indexed_slices.py:424: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "from curvlearn.optimizers import RAdam\n",
    "optimizer = RAdam(learning_rate=learning_rate, manifold=manifold, c=curvature)\n",
    "train_op = optimizer.minimize(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now a non-Euclidean binary classification model is built successfully.\n",
    "\n",
    "Let's check the performance!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/312919670.py:1: The name tf.get_collection is deprecated. Please use tf.compat.v1.get_collection instead.\n",
      "\n",
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/312919670.py:1: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.\n",
      "\n",
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/312919670.py:4: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n",
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/312919670.py:5: The name tf.local_variables_initializer is deprecated. Please use tf.compat.v1.local_variables_initializer instead.\n",
      "\n",
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/312919670.py:6: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n",
      "\n",
      "WARNING:tensorflow:From /var/folders/yb/bnmcr5gd40bc3s00cshy0gr80000gp/T/ipykernel_15969/312919670.py:9: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-09-14 15:55:06.859444: I tensorflow/core/platform/cpu_feature_guard.cc:145] This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following CPU instructions in performance critical operations:  SSE4.1 SSE4.2\n",
      "To enable them in non-MKL-DNN operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2021-09-14 15:55:06.860084: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 8. Tune using inter_op_parallelism_threads for best performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No.1 batches, curvature 0.000999992829747498, loss 3.958437919616699\n",
      "No.101 batches, curvature 0.015423169359564781, loss 0.6913518905639648\n",
      "No.201 batches, curvature 0.015037328004837036, loss 0.6907403469085693\n",
      "No.301 batches, curvature 0.014309548772871494, loss 0.6823717951774597\n",
      "No.401 batches, curvature 0.013324601575732231, loss 0.6854066848754883\n",
      "No.501 batches, curvature 0.012068444862961769, loss 0.6762799024581909\n",
      "No.601 batches, curvature 0.010590470395982265, loss 0.6735092997550964\n",
      "No.701 batches, curvature 0.008658314123749733, loss 0.6782102584838867\n",
      "No.801 batches, curvature 0.006072953809052706, loss 0.6783410310745239\n",
      "No.901 batches, curvature 0.0031166779808700085, loss 0.6705145835876465\n",
      "No.1001 batches, curvature -0.00010512575681786984, loss 0.6698893308639526\n",
      "No.1101 batches, curvature -0.003452387172728777, loss 0.6701076626777649\n",
      "No.1201 batches, curvature -0.007259138859808445, loss 0.6591002941131592\n",
      "No.1301 batches, curvature -0.010397176258265972, loss 0.6515015363693237\n",
      "No.1401 batches, curvature -0.013047800399363041, loss 0.6649119853973389\n",
      "No.1501 batches, curvature -0.014556215144693851, loss 0.6390271186828613\n",
      "No.1601 batches, curvature -0.0157223641872406, loss 0.6614259481430054\n",
      "No.1701 batches, curvature -0.01662936434149742, loss 0.6651350855827332\n",
      "No.1801 batches, curvature -0.01760878972709179, loss 0.6602005362510681\n",
      "No.1901 batches, curvature -0.017814118415117264, loss 0.6431636214256287\n",
      "No.2001 batches, curvature -0.019135411828756332, loss 0.6432427167892456\n",
      "No.2101 batches, curvature -0.01914401911199093, loss 0.6454476118087769\n",
      "No.2201 batches, curvature -0.020368380472064018, loss 0.6371751427650452\n",
      "No.2301 batches, curvature -0.020331185311079025, loss 0.6319084763526917\n",
      "No.2401 batches, curvature -0.020242569968104362, loss 0.6406416296958923\n",
      "No.2501 batches, curvature -0.020820366218686104, loss 0.6401246786117554\n",
      "No.2601 batches, curvature -0.020268552005290985, loss 0.6385504007339478\n",
      "No.2701 batches, curvature -0.020809054374694824, loss 0.6374597549438477\n",
      "No.2801 batches, curvature -0.020833346992731094, loss 0.6265876889228821\n",
      "No.2901 batches, curvature -0.021381892263889313, loss 0.6239088773727417\n",
      "No.3001 batches, curvature -0.020717168226838112, loss 0.6319794058799744\n",
      "No.3101 batches, curvature -0.019755186513066292, loss 0.6216168403625488\n",
      "No.3201 batches, curvature -0.020886586979031563, loss 0.6256126165390015\n",
      "No.3301 batches, curvature -0.02243652753531933, loss 0.6150191426277161\n",
      "No.3401 batches, curvature -0.02182244323194027, loss 0.6224449276924133\n",
      "No.3501 batches, curvature -0.02230178937315941, loss 0.6291787624359131\n",
      "No.3601 batches, curvature -0.021937668323516846, loss 0.6027733683586121\n",
      "No.3701 batches, curvature -0.022809172049164772, loss 0.6155853867530823\n",
      "No.3801 batches, curvature -0.022843029350042343, loss 0.6215240955352783\n",
      "No.3901 batches, curvature -0.02289605885744095, loss 0.6152886152267456\n",
      "No.4001 batches, curvature -0.024716004729270935, loss 0.6186487674713135\n",
      "No.4101 batches, curvature -0.02274315431714058, loss 0.6204020977020264\n",
      "No.4201 batches, curvature -0.024335019290447235, loss 0.6221641898155212\n",
      "No.4301 batches, curvature -0.025296980515122414, loss 0.5852887630462646\n",
      "No.4401 batches, curvature -0.023965677246451378, loss 0.5905001163482666\n",
      "No.4501 batches, curvature -0.024087978526949883, loss 0.6077329516410828\n",
      "No.4601 batches, curvature -0.025763778015971184, loss 0.606231689453125\n",
      "No.4701 batches, curvature -0.026451654732227325, loss 0.5950381755828857\n",
      "No.4801 batches, curvature -0.028369437903165817, loss 0.6031137704849243\n",
      "No.4901 batches, curvature -0.028076142072677612, loss 0.573093831539154\n",
      "Finish train\n"
     ]
    }
   ],
   "source": [
    "ops = [train_op, curvature, loss] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
    "\n",
    "batch_idx = 0\n",
    "global_init = tf.global_variables_initializer()\n",
    "local_init = tf.local_variables_initializer()\n",
    "cp = tf.ConfigProto()\n",
    "cp.gpu_options.allow_growth = True\n",
    "\n",
    "with tf.Session(config=cp) as sess:\n",
    "    sess.run([global_init, local_init])\n",
    "    while True:\n",
    "        try:\n",
    "            batch_idx += 1\n",
    "            _, c, loss = sess.run(ops)\n",
    "            if batch_idx % log_steps == 1:\n",
    "                print('No.{} batches, curvature {}, loss {}'.format(batch_idx, c, loss))\n",
    "\n",
    "        except tf.errors.OutOfRangeError:\n",
    "            print('Finish train')\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since our dataset is generated without any geometry prior, the curvature is trained to be near zero and the space is almost euclidean. \n",
    "\n",
    "Check performance on real dataset([recommendation](hyperml/README.md), [link prediction](hgcn/README.md), [tree pretrain](tree_pretrain/README.md)) and see the advantages of non-euclidean geometry."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
